Amazon SageMakerで特定のGitリポジトリに含まれるTensorFlowモデルを学習させる
どうも、DA事業本部の大澤です。
SageMaker Python SDKのスクリプトを見ている時にFramework用Estimatorにgit_config
なる引数があることを見つけました。
TensorFlowやMXNet等といったSageMakerがデフォルトで対応しているフレームワークに限られますが、git_config
にリポジトリ情報を指定することで、そのリポジトリのスクリプトをエントリポイントとして指定し、学習に用いることができます。
SageMaker ExamplesにあるTensorFlow用のサンプルノートブックでgit_config
が利用されていたので、試してみました。今回はその内容をお伝えします。
やってみる
amazon-sagemaker-examples/char-rnn-tensorflowでSherlock Holmesのテキストデータを学習させます。
※ サンプルノートブックに記載されている内容のうち、ローカルモードの部分は飛ばして、SageMakerで学習させる部分のみ紹介します
データ準備
まず学習に使用するSherlock Holmesのテキストデータをダウンロードします。
import os data_dir = os.path.join(os.getcwd(), 'sherlock') os.makedirs(data_dir, exist_ok=True) !wget https://sherlock-holm.es/stories/plain-text/cnus.txt --force-directories --output-document=sherlock/input.txt
git_config
続いて学習に使用するスクリプトが含まれるリポジトリをクローンするためのgit_config
を設定します。
git_config
を設定することで、学習開始時に自動的に設定に従ってリポジトリをクローンし、source_dirに指定したディレクトリがS3にアップロードされ、学習に利用されます。
git_config
には次のように辞書形式で設定を格納します。
git_config = {'repo': 'https://github.com/awslabs/amazon-sagemaker-examples.git', 'branch': 'training-scripts'}
git_config
では次のようなパラメータを利用できます。
名前 | 必須 | デフォルト | 説明 | 備考 |
---|---|---|---|---|
repo | o | リポジトリのURI | HTTPS/SSHどちらでもOK | |
branch | master | ブランチ名 | ||
commit | コミットのハッシュ | 未指定の場合は対象ブランチの最新コミットがクローンされる | ||
2FA_enabled | false | 2要素認証が有効かどうか | HTTPSの場合のみ有効 | |
username | ユーザ名 | HTTPSの場合のみ有効 | ||
password | パスワード | HTTPSの場合のみ有効 | ||
token | アクセストークン | HTTPSの場合のみ有効 |
git_config
で設定したリポジトリのクローンはローカル環境で実行されます。従って、SSH接続時の設定などはローカル環境に設定されているものが利用されます。
詳細な解説や使用例についてはドキュメントをご覧ください。
ハイパーパラメータ
学習に使用するはハイパーパラメータです。
hyperparameters = {'num_epochs': 1, 'data_dir': '/opt/ml/input/data/training'}
ハイパーパラメータとして設定した内容は学習実行時に次のようにスクリプトに渡されます。model_dirは設定の必要がなく、自動的に付与されます。
python train.py --num-epochs 1 --data_dir /opt/ml/input/data/training --model_dir /opt/ml/model
データアップロード
先ほどダウンロードしたデータをS3にアップロードします。inputsにはアップロード先のS3 URIが格納されます。
※ sagemaker.Session.upload_data
でバケット名を指定しない場合、sagemaker-{region}-{accountid}
という名前のバケットが自動作成されて利用されます。
import sagemaker inputs = sagemaker.Session().upload_data(path='sherlock', key_prefix='datasets/sherlock')
学習
TensorFlow用のEstimatorを用いて、学習の設定を行います。
先ほど設定したgit_config
やhyperparameters
に加えて、使用するインスタンスや使用するIAMロールなども設定します。必要に応じて設定内容は変更してください。
今回はgit_config
を設定するので、source_dir
は相対パスで設定する必要があります。次のような設定内容であればgit_config
に設定したリポジトリのディレクトリchar-rnn-tensorflow
が圧縮されてS3にアップロードされます。
引数や詳細な使い方についてはドキュメントをご覧ください。
estimator = TensorFlow(entry_point='train.py', source_dir='char-rnn-tensorflow', git_config=git_config, train_instance_type='ml.c4.xlarge', # Executes training in a ml.c4.xlarge instance train_instance_count=1, hyperparameters=hyperparameters, role=sagemaker.get_execution_role(), # 必要に応じて使用するIAMロールのARNを記載する framework_version='1.14', py_version='py3', script_mode=True) estimator.fit({'training': inputs})
次のように学習が開始されます。(リポジトリのクローン時のログは表示されません。)
...
S3にアップロードされた圧縮ファイルのパスは次のようにハイパーパラメータとして自動的に設定されているため、後から確認可能です。
estimator.hyperparameters()['sagemaker_submit_directory']
さいごに
SageMaker Python SDKのFramework用Estimatorのgit_config
を使ってみた様子を紹介しました。今回はTensorFlowでの使い方を紹介しましたが、MXNetやPyTorchなどSageMakerがデフォルトで対応している他のフレームワークでも同様に使用することができます。オープンソースのリポジトリや別リポジトリで管理している学習用スクリプトを利用する場合には便利そうです。